# -*- encoding: utf-8 -*-

"""
@File:      tc_libtalloc_libtalloc_func_004.py
@Time:      2024/04/08 14:33:20
@Author:    chenchunhu
@Version:   1.0
@Contact:   wb-cch358909@alibaba-inc.com
@License:   Mulan PSL v2
@Modify:    chenchunhu
"""

from common.basetest import LocalTest

class Test(LocalTest):
    """
    See tc_libtalloc_libtalloc_func_004.yaml for details

    Builds a small C program against libtalloc that (1) frees the same talloc
    chunk twice and (2) passes a non-talloc stack pointer to talloc_free().
    A custom abort handler longjmp()s back into the scenario so each error is
    reported on stdout instead of terminating the process.

    :avocado: tags=P2,noarch,local,fix
    """

    # Passed to the LocalTest base setUp/tearDown; presumably the packages are
    # installed before and removed after the test -- confirm in common.basetest.
    PARAM_DIC = {"pkg_name": "libtalloc libtalloc-devel gcc"}

    def setUp(self):
        """Prepare the environment and write talloc_test4.c into the cwd."""
        super().setUp(self.PARAM_DIC)
        # Notes on the generated C program:
        # * Each scenario restores the abort handler after its if/else, on
        #   BOTH paths, so the longjmp-ing handler (whose jump_buffer would
        #   point into a returned stack frame) is never left installed.
        # * In test_double_free the parent context is created before setjmp()
        #   and not modified afterwards, so it is still valid after longjmp()
        #   and is freed on both paths. The unused (leaked) context that
        #   test_invalid_pointer used to allocate has been dropped.
        # * "\\n" in this non-raw Python string becomes "\n" in the C source;
        #   the quoted "EOF" delimiter stops the shell from expanding anything.
        cmdline = '''cat >talloc_test4.c<<"EOF"
#include <talloc.h>
#include <assert.h>
#include <stdio.h>
#include <setjmp.h>
#include <signal.h>
#include <unistd.h>

// 定义一个全局的 jmp_buf 用于在错误处理函数中使用
jmp_buf jump_buffer;

// 自定义的错误处理函数
static void custom_talloc_error_handler(const char *reason)
{
    printf("Talloc error: %s\\n", reason);
    longjmp(jump_buffer, 1);
}

// 恢复默认行为退出
static void exit_with_error(const char *reason)
{
    printf("Exiting with error: %s\\n", reason);
    _exit(1);
}

void test_double_free() {
    void *root = talloc_new(NULL);

    if (setjmp(jump_buffer) == 0) {
        talloc_set_abort_fn(custom_talloc_error_handler);

        void *obj = talloc_size(root, 1024); // 假设分配1024字节
        talloc_free(obj);
        talloc_free(obj); // 尝试重复释放

        printf("Double free test: success\\n");
    } else {
        printf("Double free test: caught expected error\\n");
    }

    // 回到默认退出方式
    talloc_set_abort_fn(exit_with_error);
    talloc_free(root);
}

void test_invalid_pointer() {
    if (setjmp(jump_buffer) == 0) {
        talloc_set_abort_fn(custom_talloc_error_handler);

        int non_alloc = 123; // 此指针未由 talloc 分配内存
        talloc_free(&non_alloc); // 尝试释放非 talloc 分配内存

        printf("Invalid pointer free test: success\\n");
    } else {
        printf("Invalid pointer free test: caught expected error\\n");
    }

    // 回到默认退出方式
    talloc_set_abort_fn(exit_with_error);
}

int main() {
    test_double_free();
    test_invalid_pointer();

    printf("All tests passed.\\n");
    return 0;
}
EOF'''
        self.cmd(cmdline)

    def test(self):
        """Compile and run the program; require both errors to be caught."""
        self.cmd("gcc -o talloc_test4 talloc_test4.c -ltalloc")
        _code, talloc_result = self.cmd("./talloc_test4")
        # "All tests passed." is printed unconditionally by main(), so on its
        # own it proves nothing; also require that talloc invoked the abort
        # handler for each misuse.
        self.assertIn('Double free test: caught expected error', talloc_result)
        self.assertIn('Invalid pointer free test: caught expected error',
                      talloc_result)
        self.assertIn('All tests passed', talloc_result)

    def tearDown(self):
        """Remove build artifacts, then run the base-class teardown."""
        try:
            self.cmd("rm -rf talloc_test4.c talloc_test4")
        finally:
            # Always run the base teardown, even if artifact removal fails.
            super().tearDown(self.PARAM_DIC)
